#!/bin/bash

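# SD3 training script - all arguments are passed directly on the command line.
# Usage: ./train_sd3.sh
# Every path in the command is relative, so run this from the directory that
# contains sd_scripts/, the .safetensors model files and dataset.toml.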

# Parameters tuned for a GPU with 16GB of VRAM.
#
# The command line is collected in an array (one option per line, no
# line-continuation backslashes) and handed to the training script in
# exactly this order. All file paths are relative to the current directory.
train_script=sd_scripts/sd3_train_network.py

train_args=(
  # Checkpoint files to load
  --pretrained_model_name_or_path sd3.5_large.safetensors
  --clip_l clip_l.safetensors
  --clip_g clip_g.safetensors
  --t5xxl t5xxl_fp16.safetensors

  # Caching, output format, attention and data loading
  --cache_latents_to_disk
  --save_model_as safetensors
  --sdpa
  --persistent_data_loader_workers
  --max_data_loader_n_workers 1
  --seed 42

  # Memory / precision
  --gradient_checkpointing
  --mixed_precision bf16
  --save_precision bf16

  # Network (LoRA) definition
  --network_module networks.lora_sd3
  --network_dim 32
  --network_train_unet_only

  # Optimizer and learning rate
  --optimizer_type AdaFactor
  --learning_rate 1e-4

  # Text-encoder output caching
  --cache_text_encoder_outputs
  --cache_text_encoder_outputs_to_disk
  --highvram

  # Run length, checkpointing, dataset and output location
  --max_train_steps 200
  --save_every_n_epochs 1
  --dataset_config dataset.toml
  --output_dir ./output
  --output_name sd3_lora

  # Block swapping and model loading
  --blocks_to_swap 8
  --disable_mmap_load_safetensors

  # Extra optimizer settings, scheduler and gradient clipping
  --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
  --lr_scheduler constant_with_warmup
  --max_grad_norm 0.0
)

# Last command of the script: its exit status becomes the script's exit status.
python "$train_script" "${train_args[@]}"